//
//  mnistTrain.cpp
//  MNN
//
//  Created by MNN on 2019/11/27.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <MNN/expr/Executor.hpp>
#include <MNN/expr/Optimizer.hpp>
#include <cmath>
#include <iostream>
#include <vector>
#include "DataLoader.hpp"
#include "MnistDataset.hpp"
#include "DemoUnit.hpp"
#include "NN.hpp"
#include "SGD.hpp"
#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#include "ADAM.hpp"
#include "LearningRateScheduler.hpp"
#include "Loss.hpp"
#include "RandomGenerator.hpp"
#include "Transformer.hpp"

using namespace MNN::Train;
using namespace MNN::Express;
class MnistV2 : public Module {
public:
    MnistV2() {
        NN::ConvOption convOption;
        convOption.kernelSize = {5, 5};
        convOption.channel    = {1, 10};
        convOption.depthwise  = false;
        conv1                 = NN::Conv(convOption);
        convOption.reset();
        convOption.kernelSize = {5, 5};
        convOption.channel    = {10, 10};
        convOption.depthwise  = true;
        conv2                 = NN::Conv(convOption);
        ip1                   = NN::Linear(160, 100);
        ip2                   = NN::Linear(100, 10);
        registerModel({conv1, conv2, ip1, ip2});
    }

    virtual std::vector<VARP> onForward(const std::vector<VARP>& inputs) override {
        VARP x = inputs[0];
        x      = conv1->forward(x);
        x      = _MaxPool(x, {2, 2}, {2, 2});
        x      = conv2->forward(x);
        x      = _MaxPool(x, {2, 2}, {2, 2});
        x      = _Convert(x, NCHW);
        x      = _Reshape(x, {0, -1});
        x      = ip1->forward(x);
        x      = _Relu(x);
        x      = ip2->forward(x);
        x      = _Softmax(x, 1);
        return {x};
    }
    std::shared_ptr<Module> conv1;
    std::shared_ptr<Module> conv2;
    std::shared_ptr<Module> ip1;
    std::shared_ptr<Module> ip2;
};
/**
 * LeNet-style MNIST classifier.
 *
 * Layers: 5x5 conv 1->20, 5x5 conv 20->50, fully-connected 800->500,
 * dropout(0.5), fully-connected 500->10.
 * Forward: conv -> 2x2 max-pool -> conv -> 2x2 max-pool -> NCHW -> flatten
 * -> linear -> ReLU -> dropout -> linear -> softmax over axis 1.
 */
class Mnist : public Module {
public:
    Mnist() {
        // Scoped timer: placed first so it actually covers layer construction
        // (at the end of the body it measured an empty span).
        AUTOTIME;
        NN::ConvOption convOption;
        convOption.kernelSize = {5, 5};
        convOption.channel    = {1, 20};
        conv1                 = NN::Conv(convOption);
        convOption.reset();
        convOption.kernelSize = {5, 5};
        convOption.channel    = {20, 50};
        conv2                 = NN::Conv(convOption);
        // 800 = flattened size of the second pooling stage's output.
        ip1     = NN::Linear(800, 500);
        ip2     = NN::Linear(500, 10);
        dropout = NN::Dropout(0.5);
        registerModel({conv1, conv2, ip1, ip2, dropout});
    }

    virtual std::vector<VARP> onForward(const std::vector<VARP>& inputs) override {
        VARP x = inputs[0];
        x      = conv1->forward(x);
        x      = _MaxPool(x, {2, 2}, {2, 2});
        x      = conv2->forward(x);
        x      = _MaxPool(x, {2, 2}, {2, 2});
        // Flatten after forcing NCHW so the element order matches ip1's expectations.
        x = _Convert(x, NCHW);
        x = _Reshape(x, {0, -1});
        x = ip1->forward(x);
        x = _Relu(x);
        x = dropout->forward(x);
        x = ip2->forward(x);
        x = _Softmax(x, 1);
        return {x};
    }
    std::shared_ptr<Module> conv1;
    std::shared_ptr<Module> conv2;
    std::shared_ptr<Module> ip1;
    std::shared_ptr<Module> ip2;
    std::shared_ptr<Module> dropout;
};

static void train(std::shared_ptr<Module> model, std::string root) {
    auto exe = Executor::getGlobalExecutor();
    BackendConfig config;
    exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 2);
    std::shared_ptr<SGD> sgd(new SGD);
    sgd->append(model->parameters());
    sgd->setMomentum(0.9f);
    // sgd->setMomentum2(0.99f);
    sgd->setWeightDecay(0.0005f);

    auto dataset = std::make_shared<MnistDataset>(root, MnistDataset::Mode::TRAIN);
    // the stack transform, stack [1, 28, 28] to [n, 1, 28, 28]
    auto transform = std::make_shared<StackTransform>();

    const size_t batchSize  = 64;
    const size_t numWorkers = 4;
    bool shuffle            = true;

    auto dataLoader = DataLoader::makeDataLoader(dataset, {transform}, batchSize, shuffle, numWorkers);

    const size_t iterations = dataset->size() / batchSize;

    auto testDataset            = std::make_shared<MnistDataset>(root, MnistDataset::Mode::TEST);
    const size_t testBatchSize  = 20;
    const size_t testNumWorkers = 1;
    shuffle                     = false;

    auto testDataLoader = DataLoader::makeDataLoader(testDataset, {transform}, testBatchSize, shuffle, testNumWorkers);

    const size_t testIterations = testDataset->size() / testBatchSize;

    for (int epoch = 0; epoch < 50; ++epoch) {
        exe->gc();
        int correct = 0;
        testDataLoader->reset();
        model->setIsTraining(false);
        for (int i = 0; i < testIterations; i++) {
            if ((i + 1) % 100 == 0) {
                std::cout << "test iteration: " << (i + 1) << std::endl;
            }
            auto data       = testDataLoader->next();
            auto example    = data[0];
            auto cast       = _Cast<float>(example.data[0]);
            example.data[0] = cast * _Const(1.0f / 255.0f);
            auto predict    = model->forward(example.data[0]);
            predict         = _ArgMax(predict, 1);
            auto accu       = _Cast<int32_t>(_Equal(predict, _Cast<int32_t>(example.target[0]))).sum({});
            correct += accu->readMap<int32_t>()[0];
        }
        auto accu = (float)correct / (float)testDataset->size();
        std::cout << "epoch: " << epoch << "  accuracy: " << accu << std::endl;

        dataLoader->reset();
        AUTOTIME;
        model->setIsTraining(true);
        for (int i = 0; i < iterations; i++) {
            // AUTOTIME;
            auto trainData  = dataLoader->next();
            auto example    = trainData[0];
            auto cast       = _Cast<float>(example.data[0]);
            example.data[0] = cast * _Const(1.0f / 255.0f);

            // Compute One-Hot
            auto newTarget = _OneHot(_Cast<int32_t>(example.target[0]), _Scalar<int>(10), _Scalar<float>(1.0f),
                                     _Scalar<float>(0.0f));

            auto predict = model->forward(example.data[0]);
            auto loss    = _CrossEntropy(predict, newTarget);
            float rate   = LrScheduler::inv(0.01, epoch * iterations + i, 0.0001, 0.75);
            sgd->setLearningRate(rate);
            if ((epoch * iterations + i) % 100 == 0) {
                std::cout << "train iteration: " << epoch * iterations + i;
                std::cout << " loss: " << loss->readMap<float>()[0];
                std::cout << " lr: " << rate << std::endl;
            }
            sgd->step(loss);
            if (i == iterations - 1) {
                model->setIsTraining(false);
                predict = model->forward(_Input({1, 1, 28, 28}, NCHW));
                Variable::save({predict}, "temp.mnist.mnn");
            }
        }
    }
}

class MnistTrain : public DemoUnit {
public:
    virtual int run(int argc, const char* argv[]) override {
        if (argc < 2) {
            std::cout << "usage: ./runTrainDemo.out MnistTrain /path/to/unzipped/mnist/data/  [depthwise]" << std::endl;
            return 0;
        }
        // global random number generator, should invoke before construct the model and dataset
        RandomGenerator::generator(17);

        auto exe = Executor::getGlobalExecutor();
        BackendConfig config;
        exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 2);

        std::string root = argv[1];
        std::shared_ptr<Module> model(new Mnist);
        if (argc >= 3) {
            model.reset(new MnistV2);
        }
        train(model, root);
        return 0;
    }
};

/**
 * Demo entry "PostTrain": fine-tunes an existing MNN model on MNIST.
 *
 * Usage: PostTrain /path/to/mnistModel /path/to/unzipped/mnist/data/
 * Loads the model's variable map from argv[1] and runs the trainable
 * transformation over its outputs. It then wraps inputs/outputs into a Module
 * and hands it to train() with argv[2] as the data root. Returns 0 in every
 * case; a load failure is reported through MNN_ERROR.
 */
class PostTrain : public DemoUnit {
public:
    virtual int run(int argc, const char* argv[]) override {
        const bool missingArgs = argc < 3;
        if (missingArgs) {
            std::cout << "usage: ./runTrainDemo.out PostTrain /path/to/mnistModel /path/to/unzipped/mnist/data/ "
                      << std::endl;
            return 0;
        }
        const char* modelPath = argv[1];
        std::string dataRoot(argv[2]);

        auto variables = Variable::loadMap(modelPath);
        if (variables.empty()) {
            MNN_ERROR("Can not load model %s\n", modelPath);
            return 0;
        }

        auto io                  = Variable::getInputAndOutput(variables);
        const auto& graphInputs  = io.first;
        const auto& graphOutputs = io.second;

        // Rewrite the loaded graph into a trainable form before wrapping it.
        Transformer::turnModelToTrainable(Transformer::TrainConfig())
            ->onExecute(Variable::mapToSequence(graphOutputs));

        std::shared_ptr<Module> model(
            Module::transform(Variable::mapToSequence(graphInputs), Variable::mapToSequence(graphOutputs)));
        train(model, dataRoot);
        return 0;
    }
};

// Register each demo under the name given on the command line
// (matching the usage strings printed by the run() methods).
DemoUnitSetRegister(MnistTrain, "MnistTrain");
DemoUnitSetRegister(PostTrain, "PostTrain");
